import functools
import inspect
import random
from abc import ABC
from collections.abc import Awaitable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from types import MappingProxyType
from typing import Any, Callable, Optional, Union, cast

from typing_extensions import Protocol, TypedDict, TypeGuard, runtime_checkable

from phoenix.client.__generated__ import v1

# Type aliases
# JSON-compatible values: None, dicts, lists, strings, numbers and booleans.
JSONSerializable = Optional[Union[dict[str, Any], list[Any], str, int, float, bool]]
# Identifier aliases: plain `str` at runtime (RepetitionNumber is an `int`).
ExperimentId = str
DatasetId = str
DatasetVersionId = str
ExampleId = str
RepetitionNumber = int
ExperimentRunId = str
TraceId = str
TaskOutput = JSONSerializable
# Example fields are typed as read-only `Mapping`s of JSON-compatible values.
ExampleOutput = Mapping[str, JSONSerializable]
ExampleMetadata = Mapping[str, JSONSerializable]
ExampleInput = Mapping[str, JSONSerializable]
# Components of an evaluation (see the `tuple[Score, Label, Explanation]`
# evaluator output form); each may be None.
Score = Optional[Union[bool, int, float]]
Label = Optional[str]
Explanation = Optional[str]
EvaluatorName = str
EvaluatorKind = str

# Prefix used by `_dry_run_id()` for locally generated placeholder IDs.
DRY_RUN = "DRY_RUN"

# Re-export the autogenerated API types under shorter names.
Experiment = v1.Experiment
ExperimentRun = v1.ExperimentRun


class ExperimentEvaluation(v1.ExperimentEvaluationResult, total=False):
    """
    Evaluation dictionary returned by an evaluator.

    Extends the generated ``v1.ExperimentEvaluationResult`` with two extra keys.
    ``total=False`` makes only the keys declared here (``name``, ``metadata``)
    optional; requiredness of the inherited keys comes from the generated type.
    """

    name: Optional[str]
    metadata: Mapping[str, Any]


# An evaluator may return a single evaluation or a sequence of evaluations.
EvaluationResult = Union[ExperimentEvaluation, Sequence[ExperimentEvaluation]]


def is_evaluation_result(obj: Any) -> TypeGuard[EvaluationResult]:
    """
    Return True if ``obj`` is an ``ExperimentEvaluation`` or a sequence of them.

    ``str``/``bytes``/``bytearray`` are rejected even though they are
    ``Sequence`` instances: otherwise an empty string would vacuously count as
    an empty sequence of evaluations. An empty list or tuple is still accepted.
    """
    if is_experiment_evaluation(obj):
        return True
    if isinstance(obj, (str, bytes, bytearray)) or not isinstance(obj, Sequence):
        return False
    return all(is_experiment_evaluation(item) for item in obj)  # pyright: ignore[reportUnknownVariableType]


# Optional keys of an evaluation dict and the type each must have when present
# and not None. `score` accepts any int (including bool) or float.
_EXPERIMENT_EVALUATION_FIELD_TYPES: tuple[tuple[str, Any], ...] = (
    ("label", str),
    ("score", (int, float)),
    ("explanation", str),
    ("name", str),
    ("metadata", Mapping),
)


def is_experiment_evaluation(obj: Any) -> TypeGuard[ExperimentEvaluation]:
    """
    Return True if ``obj`` is a dict shaped like an ``ExperimentEvaluation``.

    Every key is optional (so ``{}`` passes) and unknown keys are ignored. When
    present and not None, ``label``, ``explanation`` and ``name`` must be
    strings, ``score`` must be an int/float, and ``metadata`` must be a Mapping.
    """
    if not isinstance(obj, dict):
        return False
    data = cast(dict[str, Any], obj)
    for key, expected_type in _EXPERIMENT_EVALUATION_FIELD_TYPES:
        value = data.get(key)
        if value is not None and not isinstance(value, expected_type):
            return False
    return True


class AnnotatorKind(Enum):
    """
    What produced an evaluation: code or an LLM.

    ``BaseEvaluator.kind`` reports one of these values and falls back to
    ``CODE`` when a subclass does not set ``_kind``.
    """

    CODE = "CODE"
    LLM = "LLM"


def _dry_run_id() -> str:
    """
    Return a locally generated placeholder identifier.

    The ID is ``DRY_RUN`` followed by an underscore and 24 random bits rendered
    as exactly six zero-padded lowercase hex digits.
    """
    random_bits = random.getrandbits(24)
    return f"{DRY_RUN}_{random_bits:06x}"


@runtime_checkable
class EvaluationScore(Protocol):
    """
    Structural type for a single score object produced by an evaluator.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks
    that every member below is present on the object; member types are not
    verified.
    """

    name: Optional[str]
    score: Optional[Union[float, int]]
    label: Optional[str]
    explanation: Optional[str]
    metadata: dict[str, Any]
    source: str
    direction: str


# An evaluator may return one score object or a sequence of them.
ScoreResult = Union[EvaluationScore, Sequence[EvaluationScore]]


def is_score_result(obj: Any) -> TypeGuard[ScoreResult]:
    """
    Return True if ``obj`` is an ``EvaluationScore`` or a sequence of them.

    ``str``/``bytes``/``bytearray`` are rejected explicitly: they are
    ``Sequence`` instances, and an empty one would otherwise vacuously satisfy
    the ``all(...)`` check. An empty list or tuple is still accepted.
    """
    if isinstance(obj, EvaluationScore):
        return True
    if isinstance(obj, (str, bytes, bytearray)) or not isinstance(obj, Sequence):
        return False
    return all(isinstance(item, EvaluationScore) for item in obj)  # pyright: ignore[reportUnknownVariableType]


@runtime_checkable
class EvalsEvaluator(Protocol):
    """
    Structural interface for evaluator objects that score a single ``input`` dict.

    Such objects expose sync and async ``evaluate`` methods returning a
    ``ScoreResult``, plus descriptive attributes. Because the protocol is
    ``runtime_checkable``, ``isinstance`` only verifies that every member listed
    here is present; signatures and attribute types are not checked.

    NOTE(review): the name suggests these come from a separate "evals" evaluator
    package — confirm against callers.
    """

    def evaluate(self, input: dict[str, Any]) -> ScoreResult: ...

    async def async_evaluate(self, input: dict[str, Any]) -> ScoreResult: ...

    input_schema: Any
    direction: str
    source: str
    name: str


def is_evals_evaluator(obj: Any) -> TypeGuard[EvalsEvaluator]:
    """
    Report whether ``obj`` structurally matches the ``EvalsEvaluator`` protocol.

    This is an ``isinstance`` check against a ``runtime_checkable`` protocol, so
    only the presence of the protocol's members is verified.
    """
    matches_protocol = isinstance(obj, EvalsEvaluator)
    return matches_protocol


@dataclass(frozen=True)
class TestCase:
    """Immutable pairing of a dataset example with a repetition number."""

    # NOTE(review): pytest may try to collect a class named `TestCase` if it is
    # imported into a test module; consider `__test__ = False` if that happens.
    example: v1.DatasetExample
    repetition_number: RepetitionNumber


@dataclass(frozen=True)
class ExperimentEvaluationRun:
    """
    Outcome of applying one named evaluator to one experiment run.

    ``result`` and ``error`` are each optional, but at least one of them must be
    supplied; construction raises ``ValueError`` otherwise. ``id`` defaults to a
    locally generated placeholder from ``_dry_run_id()`` and ``metadata`` to a
    fresh empty dict per instance.
    """

    experiment_run_id: ExperimentRunId
    start_time: datetime
    end_time: datetime
    name: str
    annotator_kind: str
    error: Optional[str] = None
    result: Optional[EvaluationResult] = None
    id: str = field(default_factory=_dry_run_id)
    trace_id: Optional[TraceId] = None
    metadata: Mapping[str, JSONSerializable] = field(
        default_factory=lambda: cast(dict[str, JSONSerializable], {})
    )

    def __post_init__(self) -> None:
        # Reject runs that record neither an outcome nor a failure.
        has_outcome = self.error is not None or self.result is not None
        if not has_outcome:
            raise ValueError("Must specify either result or error")


# Task and Evaluator types
ExperimentTask = Union[
    # A task may be synchronous or async. The first two forms take a single
    # dataset example; the `...` forms accept any parameters. How arguments are
    # bound to them is decided by the caller, which is not visible here.
    Callable[[v1.DatasetExample], TaskOutput],
    Callable[[v1.DatasetExample], Awaitable[TaskOutput]],
    Callable[..., JSONSerializable],
    Callable[..., Awaitable[JSONSerializable]],
]

# Return values accepted from an evaluator: evaluation dicts, score objects,
# bare scalars (bool/int/float/str), a (score, label, explanation) tuple, or an
# arbitrary dict. NOTE(review): normalization of these forms happens elsewhere.
EvaluatorOutput = Union[
    EvaluationResult,
    ScoreResult,
    bool,
    int,
    float,
    str,
    tuple[Score, Label, Explanation],
    dict[str, Any],
]


@runtime_checkable
class Evaluator(Protocol):
    """
    Protocol for evaluators that can score experiment outputs.

    Any object implementing this protocol can be used as an evaluator.
    Subclasses should implement either the `evaluate` or `async_evaluate` method.
    Implementing both methods is recommended, but not required.

    Because the protocol is `runtime_checkable`, `isinstance(obj, Evaluator)`
    only checks that `name`, `kind`, `evaluate` and `async_evaluate` exist on
    `obj`; it does not verify their signatures or return types.
    """

    @property
    def name(self) -> str:
        """The name of the evaluator."""
        ...

    @property
    def kind(self) -> str:
        """The kind of evaluator (e.g., 'CODE', 'LLM')."""
        ...

    def evaluate(
        self,
        *,
        output: Optional[TaskOutput] = None,
        expected: Optional[ExampleOutput] = None,
        metadata: ExampleMetadata = MappingProxyType({}),
        input: ExampleInput = MappingProxyType({}),
        **kwargs: Any,
    ) -> EvaluationResult:
        """
        Evaluate the output synchronously.

        All arguments are keyword-only. The mapping defaults are read-only
        `MappingProxyType` instances, so sharing them across calls is safe.

        Args:
            output: The task output to evaluate.
            expected: The expected output for the example, if any.
            metadata: Metadata for the example.
            input: Input for the example.
            **kwargs: Any additional keyword arguments.

        Returns:
            A single evaluation or a sequence of evaluations.
        """
        ...

    async def async_evaluate(
        self,
        *,
        output: Optional[TaskOutput] = None,
        expected: Optional[ExampleOutput] = None,
        metadata: ExampleMetadata = MappingProxyType({}),
        input: ExampleInput = MappingProxyType({}),
        **kwargs: Any,
    ) -> EvaluationResult:
        """
        Evaluate the output asynchronously.

        Takes the same keyword-only arguments as `evaluate`.

        Returns:
            A single evaluation or a sequence of evaluations.
        """
        ...


def _validate_evaluator_signature(sig: inspect.Signature) -> None:
    """Validate that a function signature is compatible with evaluator requirements."""
    params = sig.parameters
    valid_named_params = {"input", "output", "expected", "reference", "metadata", "example"}
    if len(params) == 0:
        raise ValueError("Evaluator function must have at least one parameter.")
    if len(params) > 1:
        for param_name in set(params) - valid_named_params:
            param = params[param_name]
            if (
                param.kind is inspect.Parameter.VAR_KEYWORD
                or param.default is not inspect.Parameter.empty
            ):
                continue
            raise ValueError(
                f"Invalid parameter name in evaluator function: {param_name}. "
                "Parameter names for multi-argument functions must be "
                f"any of: {', '.join(valid_named_params)}."
            )


def _validate_evaluator_method_signature(fn: Callable[..., Any], fn_name: str) -> None:
    """Validate that an evaluator method has the correct signature."""
    sig = inspect.signature(fn)
    _validate_evaluator_signature(sig)
    for param in sig.parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return
    else:
        raise ValueError(f"`{fn_name}` should allow variadic keyword arguments `**kwargs`")


def _validate_evaluator_override(method: Any, method_name: str) -> None:
    """
    Validate an ``evaluate``/``async_evaluate`` attribute taken from a class ``__dict__``.

    Args:
        method: The raw class attribute (a function or a ``classmethod`` wrapper).
        method_name: ``"evaluate"`` or ``"async_evaluate"``; used in error messages.

    Raises:
        TypeError: If the attribute is not callable.
        ValueError: If its signature is not a valid evaluator method signature.
    """
    if isinstance(method, classmethod):
        method = method.__func__  # pyright: ignore[reportUnknownMemberType]
    # Raise instead of `assert`: assertions are stripped under `python -O`.
    if not callable(method):
        raise TypeError(f"`{method_name}()` method should be callable")
    # Bind a placeholder for the leading `self`/`cls` parameter so that only the
    # parameters callers actually pass are validated.
    _validate_evaluator_method_signature(functools.partial(method, None), method_name)


class BaseEvaluator(ABC, Evaluator):
    """
    A helper abstract class to guide the implementation of an `Evaluator` object.
    Subclasses must implement either the `evaluate` or `async_evaluate` method.
    Implementing both methods is recommended, but not required.

    This class is intended to be subclassed, and should not be instantiated directly.
    The implementation requirement is checked when a subclass is defined. A
    subclass that sets `__abstract__ = True` in its own class body is exempt from
    that check; the exemption is not inherited by its subclasses.
    """

    _kind: AnnotatorKind
    _name: EvaluatorName

    @property
    def name(self) -> EvaluatorName:
        """The evaluator's name: `_name` if set, otherwise the class name."""
        if hasattr(self, "_name"):
            return self._name
        return self.__class__.__name__

    @property
    def kind(self) -> EvaluatorKind:
        """The evaluator's kind: the value of `_kind` if set, otherwise `"CODE"`."""
        if hasattr(self, "_kind"):
            return self._kind.value
        return AnnotatorKind.CODE.value

    def __new__(cls, *args: Any, **kwargs: Any) -> "BaseEvaluator":
        # No methods are marked `@abstractmethod`, so `ABC` alone would not
        # prevent instantiating the base class; block it explicitly.
        if cls is BaseEvaluator:
            raise TypeError(f"{cls.__name__} is an abstract class and should not be instantiated.")
        return object.__new__(cls)

    def evaluate(
        self,
        *,
        output: Optional[TaskOutput] = None,
        expected: Optional[ExampleOutput] = None,
        metadata: ExampleMetadata = MappingProxyType({}),
        input: ExampleInput = MappingProxyType({}),
        **kwargs: Any,
    ) -> EvaluationResult:
        """
        Evaluate the output synchronously.

        For subclassing, one should implement either this sync method or the
        async version. Implementing both is recommended but not required.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError

    async def async_evaluate(
        self,
        *,
        output: Optional[TaskOutput] = None,
        expected: Optional[ExampleOutput] = None,
        metadata: ExampleMetadata = MappingProxyType({}),
        input: ExampleInput = MappingProxyType({}),
        **kwargs: Any,
    ) -> EvaluationResult:
        """
        Evaluate the output asynchronously.

        For subclassing, one should implement either this async method or the
        sync version. Implementing both is recommended but not required.
        The default implementation delegates to the synchronous `evaluate`.
        """
        return self.evaluate(
            output=output,
            expected=expected,
            metadata=metadata,
            input=input,
            **kwargs,
        )

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Check at class-definition time that the subclass implements an evaluation method.

        Raises:
            ValueError: If neither `evaluate` nor `async_evaluate` is defined
                below `BaseEvaluator` in the MRO, or the found method has an
                invalid signature.
            TypeError: If the found attribute is not callable.
        """
        super().__init_subclass__(**kwargs)

        # Look only at the class's own namespace: `getattr` would let the flag
        # leak from an abstract parent and silently disable validation for all
        # of its concrete subclasses.
        if cls.__dict__.get("__abstract__", False):
            return

        method_names = (BaseEvaluator.evaluate.__name__, BaseEvaluator.async_evaluate.__name__)
        # Search the MRO up to (excluding) BaseEvaluator; the first class that
        # defines either method decides what gets validated, `evaluate` first.
        for super_cls in inspect.getmro(cls):
            if super_cls is BaseEvaluator:
                break
            for method_name in method_names:
                if method := super_cls.__dict__.get(method_name):
                    _validate_evaluator_override(method, method_name)
                    return

        evaluate_fn_signature = inspect.signature(BaseEvaluator.evaluate)
        raise ValueError(
            f"Evaluator must implement either "
            f"`def evaluate{evaluate_fn_signature}` or "
            f"`async def async_evaluate{evaluate_fn_signature}`"
        )


class RanExperiment(TypedDict):
    """
    A completed experiment with its results.

    This represents an experiment that has been run and contains both the experiment
    metadata and the task runs. It can be used as input to evaluate_experiment to
    add additional evaluations.
    """

    experiment_id: ExperimentId
    dataset_id: DatasetId
    dataset_version_id: DatasetVersionId
    # Task runs recorded for this experiment.
    task_runs: list[ExperimentRun]
    # Evaluator outcomes, each referencing a task run via `experiment_run_id`.
    evaluation_runs: list[ExperimentEvaluationRun]
    experiment_metadata: Mapping[str, Any]
    # May be None; presumably when no project is associated — TODO confirm.
    project_name: Optional[str]


class ExampleProxy(Mapping[str, Any]):
    """Immutable proxy for backward compatibility with legacy Example dataclass interface.

    This proxy bridges the gap between the new v1.DatasetExample TypedDict format
    and the legacy Example dataclass that user code expects. It provides both
    object-style attribute access (example.input) and dictionary-style access
    (example["input"]) while maintaining immutability.

    The proxy performs necessary type conversions to match the legacy interface:
    - updated_at: str (from API) → datetime (legacy interface)
    - Maintains Mapping[str, Any] interface for dictionary operations
    - Provides typed properties for id, updated_at, input, output, metadata

    This enables a seamless migration from the legacy Example dataclass to the
    new TypedDict-based API responses without breaking existing user code.

    Args:
        wrapped: The v1.DatasetExample TypedDict to wrap with legacy interface.

    Note:
        This class is immutable - all attempts to modify attributes will raise
        AttributeError. The proxy itself never modifies the wrapped data.
        Instances support ``copy.copy``, ``copy.deepcopy`` and ``pickle``; these
        rebuild the proxy through its constructor (see ``__reduce__``).
    """

    __slots__ = ("__wrapped__",)

    def __init__(self, wrapped: v1.DatasetExample) -> None:
        # `__setattr__` is blocked below, so bypass it for the one-time assignment.
        object.__setattr__(self, "__wrapped__", wrapped)

    @property
    def id(self) -> str:
        """Access to id field."""
        return self.__wrapped__["id"]  # type: ignore[no-any-return, attr-defined]

    @property
    def updated_at(self) -> datetime:
        """Access to updated_at field, parsed from an ISO 8601 string."""
        timestamp_str = self.__wrapped__["updated_at"]  # type: ignore[attr-defined]
        # Convert Z suffix to +00:00 for Python 3.9 compatibility
        if timestamp_str.endswith("Z"):
            timestamp_str = timestamp_str[:-1] + "+00:00"
        return datetime.fromisoformat(timestamp_str)

    @property
    def input(self) -> Mapping[str, Any]:
        """Access to input field."""
        return self.__wrapped__["input"]  # type: ignore[no-any-return, attr-defined]

    @property
    def output(self) -> Mapping[str, Any]:
        """Access to output field."""
        return self.__wrapped__["output"]  # type: ignore[no-any-return, attr-defined]

    @property
    def metadata(self) -> Mapping[str, Any]:
        """Access to metadata field."""
        return self.__wrapped__["metadata"]  # type: ignore[no-any-return, attr-defined]

    def __getitem__(self, key: str) -> Any:
        """Support dictionary-style access."""
        return self.__wrapped__[key]  # type: ignore[attr-defined]

    def get(self, key: str, default: Any = None) -> Any:
        """Dictionary-style get method with default value."""
        try:
            return self.__wrapped__[key]  # type: ignore[attr-defined]
        except KeyError:
            return default

    def __iter__(self) -> Iterator[str]:
        """Support iteration over dictionary keys."""
        return iter(self.__wrapped__)  # type: ignore[attr-defined]

    def __len__(self) -> int:
        """Support len() function for dictionary-style length."""
        return len(self.__wrapped__)  # type: ignore[attr-defined]

    def __setattr__(self, name: str, value: Any) -> None:
        """Prevent attribute assignment to maintain immutability."""
        raise AttributeError(f"'{type(self).__name__}' object is immutable")

    def __delattr__(self, name: str) -> None:
        """Prevent attribute deletion to maintain immutability."""
        raise AttributeError(f"'{type(self).__name__}' object is immutable")

    def __reduce__(self) -> tuple[Any, ...]:
        """Support ``copy``, ``deepcopy`` and ``pickle``.

        The default reduction for a ``__slots__`` class restores the slot value
        with ``setattr``, which ``__setattr__`` forbids, so copying or unpickling
        would raise ``AttributeError``. Rebuild through the constructor instead.
        """
        return (type(self), (self.__wrapped__,))  # type: ignore[attr-defined]

    def __repr__(self) -> str:
        """Developer representation showing it's an immutable proxy."""
        return f"ExampleProxy({self.__wrapped__!r})"  # type: ignore[attr-defined]


# Type aliases for evaluators
# A single evaluator: an `Evaluator` or `EvalsEvaluator` object, or a plain
# (sync or async) callable returning any `EvaluatorOutput`.
ExperimentEvaluator = Union[
    Evaluator,
    EvalsEvaluator,
    Callable[..., EvaluatorOutput],
    Callable[..., Awaitable[EvaluatorOutput]],
]
# One evaluator, a sequence of evaluators, or a mapping from name to evaluator.
ExperimentEvaluators = Union[
    ExperimentEvaluator,
    Sequence[ExperimentEvaluator],
    Mapping[EvaluatorName, ExperimentEvaluator],
]
# One exception type or a sequence of them; presumably treated as rate-limit
# signals by the code that consumes this alias — TODO confirm.
RateLimitErrors = Union[type[BaseException], Sequence[type[BaseException]]]
